--- title: "EPSCN: Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network" keywords: fastai sidebar: home_sidebar ---
# IPython/Jupyter magics: these lines only run inside a notebook kernel,
# not under a plain Python interpreter.
# Enable IPython's autoreload extension (mode 2) so edits to the local
# `superres` package are picked up without restarting the kernel.
%reload_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook.
%matplotlib inline
import os
import cv2
import numpy as np
import re
import random
from tqdm import tqdm
from matplotlib import pyplot as plt
import PIL
import sys
sys.path.append('..')
from superres.datasets import *
from superres.databunch import *
# Seed every RNG this notebook draws from so runs are repeatable.
# The super-resolution model is a PyTorch module, so its weight
# initialisation draws from torch's RNG and must be seeded too.
# `torch.manual_seed` seeds the CPU and all CUDA devices.
import torch

seed = 8610
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
# Smoke test: push a dummy 1x1x64x64 tensor through an untrained EPSCN
# (upscale factor 4) and print the output size.
# Related reference implementation: https://github.com/nyk510/srcnn-pytorch
# Use the GPU when one is available; otherwise fall back to CPU instead of
# failing on an unconditional .cuda() call.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
test_input = torch.ones(1, 1, 64, 64, device=device)
g = EPSCN(4).to(device)
# Inference only: no need to build an autograd graph for this check.
with torch.no_grad():
    out = g(test_input)
# Presumably (1, 1, 256, 256) if EPSCN upsamples H and W by its factor — not asserted here.
print(out.size())
# Training data built from the DIV2K HR crops (div2k_train_hr_crop_256).
# The LR side uses in_size and the HR side out_size (= in_size * scale here),
# opened in YCbCr mode; the split/shuffle seed is the notebook-wide `seed`.
scale = 4
in_size = 64
out_size = 256
bs = 10
train_hr = div2k_train_hr_crop_256
data = create_sr_databunch(
    train_hr,
    in_size=in_size,
    out_size=out_size,
    scale=scale,
    convert_mode='YCbCr',
    bs=bs,
    seed=seed,
)
print(data)
# Preview one batch in grayscale.
data.show_batch(cmap='gray')
# Build the learner: EPSCN trained with flattened MSE loss and tracked with
# the m_psnr / m_ssim metrics.
# Use the same `scale` as the databunch so model and data cannot drift apart.
model = EPSCN(upscale=scale)
loss_func = MSELossFlat()
metrics = [m_psnr, m_ssim]
learn = Learner(data, model, loss_func=loss_func, metrics=metrics)
# Root the learner's working path at the current directory
# (presumably where saved checkpoints end up — confirm).
learn.path = Path('.')
# Class name ("EPSCN") doubles as the checkpoint / report name below.
model_name = type(model).__name__
# Learning-rate range test, plotted with the suggested learning rate.
lr_find(learn)
learn.recorder.plot(suggestion=True)
# One-cycle training run. lr is presumably read off the lr_find plot.
lr, pct_start, wd = 1e-3, 0.3, 1e-3
lrs = slice(lr)
epoch = 3
save_fname = model_name
callbacks = [
    ShowGraph(learn),                           # plot training curves
    SaveModelCallback(learn, name=save_fname),  # save weights under the model's class name
]
fit_kwargs = dict(pct_start=pct_start, wd=wd, callbacks=callbacks)
learn.fit_one_cycle(epoch, lrs, **fit_kwargs)
# Display sample model outputs in grayscale.
learn.show_results(cmap='gray')
# Evaluate on Set14. Three views of the same HR folder are built, all with
# luminance=True: the model input (x, at in_size), the target (y, at
# out_size), and an up-sized input (x_up, presumably the interpolation
# baseline — confirm in after_open_image).
test_hr = set14_hr


def _set14_list(**open_kwargs):
    """Return an ImageImageList over `test_hr` opened via after_open_image(**open_kwargs)."""
    return ImageImageList.from_folder(
        test_hr,
        convert_mode='YCbCr',
        after_open=partial(after_open_image, luminance=True, **open_kwargs),
    )


# Use the training `scale` instead of a literal 4 so evaluation matches the model.
il_test_x = _set14_list(size=in_size, scale=scale)
il_test_y = _set14_list(size=out_size)
il_test_x_up = _set14_list(size=out_size, scale=scale, sizeup=True)
sr_test_upscale(learn, il_test_x, il_test_y, il_test_x_up, model_name, cmap='gray')
# Officially reported bicubic baseline, for comparison: PSNR 25.99, SSIM 0.7486.
# Show the network architecture and fastai's per-layer summary.
# A bare expression like `model` is only displayed when it is the last line
# of a notebook cell and is silent when run as a script, so print explicitly.
print(model)
print(learn.summary())